aboutsummaryrefslogtreecommitdiff
path: root/cmd/mybittorrent
diff options
context:
space:
mode:
Diffstat (limited to 'cmd/mybittorrent')
-rw-r--r--cmd/mybittorrent/client.go148
-rw-r--r--cmd/mybittorrent/main.go32
-rw-r--r--cmd/mybittorrent/meta.go55
-rw-r--r--cmd/mybittorrent/peer.go69
4 files changed, 259 insertions, 45 deletions
diff --git a/cmd/mybittorrent/client.go b/cmd/mybittorrent/client.go
index 8b3aac8..a450a81 100644
--- a/cmd/mybittorrent/client.go
+++ b/cmd/mybittorrent/client.go
@@ -3,6 +3,7 @@ package main
import (
"bytes"
"encoding/binary"
+ "errors"
"fmt"
"io"
"log/slog"
@@ -11,6 +12,7 @@ import (
"net/url"
"os"
"strconv"
+ "sync"
bencode "github.com/jackpal/bencode-go"
)
@@ -22,10 +24,12 @@ type Client struct {
}
type ClientTorrent struct {
- Meta Meta
- Uploaded int
- Downloaded int
- Left int
+ Meta Meta
+ Uploaded int
+ Downloaded int
+ Left int
+ PeerResponse PeerResponse
+ Peers []*Peer
}
type PeerResponse struct {
@@ -33,6 +37,50 @@ type PeerResponse struct {
Peers []string
}
+type FileResult struct {
+ Data []byte
+ Piece Piece
+}
+
+type FileWriter struct {
+ ch chan FileResult
+ piece Piece
+}
+
+func (c *Client) ConnectPeers(filename string) (*ClientTorrent, error) {
+ ct, ok := c.Torrents[filename]
+ if !ok {
+ if err := c.AddTorrentFile(filename); err != nil {
+ return nil, err
+ }
+ }
+
+ if len(ct.PeerResponse.Peers) == 0 {
+ if _, err := c.GetPeers(filename); err != nil {
+ return nil, fmt.Errorf("failed to get peers: %v+", err)
+ }
+ }
+
+ for _, peerAddr := range ct.PeerResponse.Peers {
+ peer, err := c.Handshake(filename, peerAddr)
+ if err != nil {
+ continue
+ }
+
+ ct.Peers = append(ct.Peers, peer)
+ }
+
+ return ct, nil
+}
+
+func (c *Client) Close() (err error) {
+ for _, ct := range c.Torrents {
+ err = errors.Join(err, ct.Close())
+ }
+
+ return
+}
+
func NewClient(peerId string, port int) *Client {
return &Client{
PeerId: peerId,
@@ -70,7 +118,7 @@ func (c *Client) AddTorrentFile(filename string) error {
return nil
}
-func (ct ClientTorrent) getUrl(c Client) (string, error) {
+func (ct *ClientTorrent) getUrl(c Client) (string, error) {
u, err := url.Parse(ct.Meta.Announce)
if err != nil {
return "", err
@@ -130,10 +178,12 @@ func (c *Client) GetPeers(filename string) (PeerResponse, error) {
return PeerResponse{}, err
}
- return PeerResponse{
+ ct.PeerResponse = PeerResponse{
Interval: resp.Interval,
Peers: DecodePeers([]byte(resp.Peers)),
- }, err
+ }
+
+ return ct.PeerResponse, nil
}
func (c *Client) Handshake(filename, peerAddr string) (*Peer, error) {
@@ -187,3 +237,87 @@ func (c *Client) Handshake(filename, peerAddr string) (*Peer, error) {
return peer, nil
}
+
+func (f *FileResult) WriteTo(w io.Writer) (int64, error) {
+ n, err := w.Write(f.Data)
+ return int64(n), err
+}
+
+func (f *FileWriter) Write(b []byte) (int, error) {
+
+ f.ch <- FileResult{
+
+ Data: b,
+
+ Piece: f.piece,
+ }
+
+ return len(b), nil
+
+}
+
+func (ct *ClientTorrent) Download(out string) error {
+ pieces := ct.Meta.Pieces()
+ slog.Debug("Starting download", "pieceCnt", len(pieces), "peers", ct.PeerResponse)
+ pieceCh := make(chan Piece, len(pieces))
+ fileCh := make(chan FileResult)
+ wg := sync.WaitGroup{}
+ wg.Add(len(pieces))
+
+ for _, piece := range pieces {
+ pieceCh <- piece
+ }
+
+ for _, peer := range ct.Peers {
+ go peer.Download(pieceCh, fileCh)
+ }
+
+ go func() {
+ wg.Wait()
+ close(fileCh)
+ close(pieceCh)
+ }()
+
+ return ct.WriteFile(pieceCh, fileCh, out, &wg)
+}
+
+func (ct *ClientTorrent) Close() (err error) {
+ for _, peer := range ct.Peers {
+ err = errors.Join(err, peer.Close())
+ }
+
+ return
+}
+
+func (ct *ClientTorrent) WriteFile(pieceCh chan Piece, fileCh chan FileResult, out string, wg *sync.WaitGroup) error {
+ f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY, 0644)
+ if err != nil {
+ return err
+ }
+
+ defer f.Close()
+
+ for fr := range fileCh {
+ if _, err = f.Seek(int64(fr.Piece.Index*ct.Meta.Info.PieceLength), io.SeekStart); err != nil {
+ slog.Error("failed to seek", "err", err)
+ pieceCh <- fr.Piece
+ continue
+ }
+
+ if _, err = fr.WriteTo(f); err != nil {
+ slog.Error("failed to write piece to file", "err", err)
+ pieceCh <- fr.Piece
+ continue
+ }
+
+ if err = f.Sync(); err != nil {
+ slog.Error("failed to sync file to disk", "err", err)
+ continue
+ }
+
+ slog.Debug("writing to file", "piece index", fr.Piece.Index)
+ wg.Done()
+ }
+
+ return nil
+}
diff --git a/cmd/mybittorrent/main.go b/cmd/mybittorrent/main.go
index 80a9b7d..6dbdda6 100644
--- a/cmd/mybittorrent/main.go
+++ b/cmd/mybittorrent/main.go
@@ -75,6 +75,9 @@ func main() {
peerAddr := os.Args[3]
peer, err := c.Handshake(fn, peerAddr)
+
+ defer peer.Close()
+
if err != nil {
panic(err)
}
@@ -91,6 +94,7 @@ func main() {
}
c := createClient(fn)
+ defer c.Close()
pr, err := c.GetPeers(fn)
if err != nil {
panic(err)
@@ -109,15 +113,39 @@ func main() {
panic(fmt.Errorf("no peers found for file: %s", out))
}
- defer peer.Close()
+ f, err := os.Create(out)
+ if err != nil {
+ panic(f)
+ }
- err = peer.DownloadPiece(out, index)
+ defer f.Close()
+
+ err = peer.DownloadPiece(f, c.Torrents[fn].Meta.Pieces()[index])
if err != nil {
panic(err)
}
fmt.Printf("Piece %d downloaded to %s.", index, out)
+ case "download":
+ out := os.Args[3]
+ fn := os.Args[4]
+ c := createClient(fn)
+
+ defer c.Close()
+
+ ct, err := c.ConnectPeers(fn)
+ if err != nil {
+ panic(err)
+ }
+
+ err = ct.Download(out)
+ if err != nil {
+ panic(err)
+ }
+
+ fmt.Printf("Downloaded %s to %s", fn, out)
+
default:
fmt.Println("Unknown command: " + command)
os.Exit(1)
diff --git a/cmd/mybittorrent/meta.go b/cmd/mybittorrent/meta.go
index 27bcc12..11ade01 100644
--- a/cmd/mybittorrent/meta.go
+++ b/cmd/mybittorrent/meta.go
@@ -3,6 +3,7 @@ package main
import (
"bytes"
"crypto/sha1"
+ "fmt"
"math"
bencode "github.com/jackpal/bencode-go"
@@ -22,6 +23,13 @@ type FileInfo struct {
Pieces string `bencode:"pieces"`
}
+type Piece struct {
+ Index int
+ Len int
+ Hash string
+ Blocks []uint32
+}
+
func (m Meta) InfoHash() ([]byte, error) {
sha := sha1.New()
if err := bencode.Marshal(sha, m.Info); err != nil {
@@ -40,16 +48,6 @@ func (m Meta) PieceHashes() []string {
return hashes
}
-func (m Meta) CheckHash(pieceIndex int, data []byte) bool {
- sha := sha1.New()
-
- if _, err := bytes.NewBuffer(data).WriteTo(sha); err != nil {
- return false
- }
-
- return bytes.Equal([]byte(m.Info.Pieces[pieceIndex*20:pieceIndex*20+20]), sha.Sum(nil))
-}
-
func (m Meta) PieceCount() int {
return len(m.Info.Pieces) / 20
}
@@ -84,3 +82,40 @@ func (m Meta) BlockLens(pieceIdx int) []uint32 {
return blocks
}
+
+func (p Piece) CheckHash(data []byte) error {
+ sha := sha1.New()
+ if _, err := bytes.NewBuffer(data).WriteTo(sha); err != nil {
+ return err
+ }
+
+ exp, act := []byte(p.Hash), sha.Sum(nil)
+ if !bytes.Equal(exp, act) {
+ return fmt.Errorf("expected hash: %x, actual: %x", exp, act)
+ }
+
+ return nil
+}
+
+func (m Meta) Pieces() []Piece {
+ cnt := m.PieceCount()
+ pieces := make([]Piece, cnt)
+
+ for i := 0; i < cnt; i++ {
+ p := Piece{
+ Index: i,
+ Hash: m.Info.Pieces[i*20 : i*20+20],
+ Blocks: m.BlockLens(i),
+ }
+
+ if i < cnt-1 {
+ p.Len = m.Info.PieceLength
+ } else {
+ p.Len = m.Info.Length - i*m.Info.PieceLength
+ }
+
+ pieces[i] = p
+ }
+
+ return pieces
+}
diff --git a/cmd/mybittorrent/peer.go b/cmd/mybittorrent/peer.go
index 3b40493..aee5f26 100644
--- a/cmd/mybittorrent/peer.go
+++ b/cmd/mybittorrent/peer.go
@@ -7,8 +7,8 @@ import (
"errors"
"fmt"
"io"
+ "log/slog"
"net"
- "os"
)
type HandshakeMessage []byte
@@ -17,7 +17,7 @@ type Peer struct {
conn net.Conn
handshake HandshakeMessage
ct *ClientTorrent
- msgCh chan *IncomingMessage
+ unchoked bool
}
func (p *Peer) PeerIdHexString() string {
@@ -37,31 +37,36 @@ func (p *Peer) Close() error {
return nil
}
-func (p *Peer) DownloadPiece(outFile string, index int) error {
- if index >= p.ct.Meta.PieceCount() {
+func (p *Peer) DownloadPiece(w io.Writer, piece Piece) error {
+ if piece.Index >= p.ct.Meta.PieceCount() {
return nil
}
var data = new(bytes.Buffer)
- blockLens := p.ct.Meta.BlockLens(index)
var blockIndex = 0
requestFn := func() error {
- if blockIndex == len(blockLens) {
+ if blockIndex == len(piece.Blocks) {
return nil
}
r := RequestPayload{
- Index: uint32(index),
+ Index: uint32(piece.Index),
Begin: uint32(blockIndex * BlockSize),
- Length: blockLens[blockIndex],
+ Length: piece.Blocks[blockIndex],
}
return p.WriteMessage(MessageTypeRequest, r.Bytes())
}
+ if p.unchoked {
+ if err := requestFn(); err != nil {
+ return err
+ }
+ }
+
for {
- if blockIndex == len(blockLens) {
+ if blockIndex == len(piece.Blocks) {
break
}
@@ -85,6 +90,7 @@ func (p *Peer) DownloadPiece(outFile string, index int) error {
if err = requestFn(); err != nil {
return err
}
+ p.unchoked = true
case MessageTypePiece:
var block BlockPayload
@@ -105,31 +111,24 @@ func (p *Peer) DownloadPiece(outFile string, index int) error {
return err
}
+ case MessageTypeChoke:
+ p.unchoked = false
+
default:
return fmt.Errorf("unimplemented message type: %d", msg.MessageType)
}
}
- if !p.ct.Meta.CheckHash(index, data.Bytes()) {
- return fmt.Errorf("invalid hash value")
- }
-
- out, err := os.Create(outFile)
- if err != nil {
- return errors.Join(fmt.Errorf("failed to create file"), err)
- }
+ slog.Debug("piece downloaded", "piece index", piece.Index)
- _, err = out.Write(data.Bytes())
- if err != nil {
- return errors.Join(fmt.Errorf("failed to write file"), err)
+ if err := piece.CheckHash(data.Bytes()); err != nil {
+ return fmt.Errorf("invalid hash value: %v", err)
}
- err = out.Close()
- if err != nil {
- return errors.Join(fmt.Errorf("failed to close file"), err)
- }
-
- return nil
+ slog.Debug("hash ok", "piece index", piece.Index)
+ _, err := w.Write(data.Bytes())
+ slog.Debug("written to writer", "piece index", piece.Index)
+ return err
}
func (p *Peer) ReadMessage() (*IncomingMessage, error) {
@@ -179,3 +178,21 @@ func (p *Peer) WriteMessage(t MessageType, payload []byte) error {
_, err := msg.Write(payload)
return err
}
+
+func (p *Peer) Download(pieceCh chan Piece, fileCh chan FileResult) {
+ for piece := range pieceCh {
+ fw := FileWriter{
+ ch: fileCh,
+ piece: piece,
+ }
+
+ err := p.DownloadPiece(&fw, piece)
+ if err != nil {
+ slog.Error("failed to download piece", "piece", piece, "err", err)
+ pieceCh <- piece
+ continue
+ }
+
+ slog.Debug("downloaded piece", "piece index", piece.Index, "left", len(pieceCh))
+ }
+}