diff options
| author | jet2tlf <jet2tlf@gmail.com> | 2024-06-03 18:31:42 +0000 |
|---|---|---|
| committer | jet2tlf <jet2tlf@gmail.com> | 2024-06-03 18:31:42 +0000 |
| commit | 210fb1e02453413d1ce070b70c850807286a1a7a (patch) | |
| tree | dbc49ba086460dfbf62ef1d2d602cd8da46e6df2 | |
| parent | 853be358804a6e30e857035ffda81a06df3f6b74 (diff) | |
| download | bittorrent-go-210fb1e02453413d1ce070b70c850807286a1a7a.tar.gz bittorrent-go-210fb1e02453413d1ce070b70c850807286a1a7a.zip | |
| -rw-r--r-- | cmd/mybittorrent/client.go | 148 | ||||
| -rw-r--r-- | cmd/mybittorrent/main.go | 32 | ||||
| -rw-r--r-- | cmd/mybittorrent/meta.go | 55 | ||||
| -rw-r--r-- | cmd/mybittorrent/peer.go | 69 |
4 files changed, 259 insertions, 45 deletions
diff --git a/cmd/mybittorrent/client.go b/cmd/mybittorrent/client.go index 8b3aac8..a450a81 100644 --- a/cmd/mybittorrent/client.go +++ b/cmd/mybittorrent/client.go @@ -3,6 +3,7 @@ package main import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "log/slog" @@ -11,6 +12,7 @@ import ( "net/url" "os" "strconv" + "sync" bencode "github.com/jackpal/bencode-go" ) @@ -22,10 +24,12 @@ type Client struct { } type ClientTorrent struct { - Meta Meta - Uploaded int - Downloaded int - Left int + Meta Meta + Uploaded int + Downloaded int + Left int + PeerResponse PeerResponse + Peers []*Peer } type PeerResponse struct { @@ -33,6 +37,50 @@ type PeerResponse struct { Peers []string } +type FileResult struct { + Data []byte + Piece Piece +} + +type FileWriter struct { + ch chan FileResult + piece Piece +} + +func (c *Client) ConnectPeers(filename string) (*ClientTorrent, error) { + ct, ok := c.Torrents[filename] + if !ok { + if err := c.AddTorrentFile(filename); err != nil { + return nil, err + } + } + + if len(ct.PeerResponse.Peers) == 0 { + if _, err := c.GetPeers(filename); err != nil { + return nil, fmt.Errorf("failed to get peers: %v+", err) + } + } + + for _, peerAddr := range ct.PeerResponse.Peers { + peer, err := c.Handshake(filename, peerAddr) + if err != nil { + continue + } + + ct.Peers = append(ct.Peers, peer) + } + + return ct, nil +} + +func (c *Client) Close() (err error) { + for _, ct := range c.Torrents { + err = errors.Join(err, ct.Close()) + } + + return +} + func NewClient(peerId string, port int) *Client { return &Client{ PeerId: peerId, @@ -70,7 +118,7 @@ func (c *Client) AddTorrentFile(filename string) error { return nil } -func (ct ClientTorrent) getUrl(c Client) (string, error) { +func (ct *ClientTorrent) getUrl(c Client) (string, error) { u, err := url.Parse(ct.Meta.Announce) if err != nil { return "", err @@ -130,10 +178,12 @@ func (c *Client) GetPeers(filename string) (PeerResponse, error) { return PeerResponse{}, err } - return PeerResponse{ + ct.PeerResponse = PeerResponse{ Interval: resp.Interval, Peers: DecodePeers([]byte(resp.Peers)), - }, err + } + + return ct.PeerResponse, nil } func (c *Client) Handshake(filename, peerAddr string) (*Peer, error) { @@ -187,3 +237,87 @@ func (c *Client) Handshake(filename, peerAddr string) (*Peer, error) { return peer, nil } + +func (f *FileResult) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(f.Data) + return int64(n), err +} + +func (f *FileWriter) Write(b []byte) (int, error) { + + f.ch <- FileResult{ + + Data: b, + + Piece: f.piece, + } + + return len(b), nil + +} + +func (ct *ClientTorrent) Download(out string) error { + pieces := ct.Meta.Pieces() + slog.Debug("Starting download", "pieceCnt", len(pieces), "peers", ct.PeerResponse) + pieceCh := make(chan Piece, len(pieces)) + fileCh := make(chan FileResult) + wg := sync.WaitGroup{} + wg.Add(len(pieces)) + + for _, piece := range pieces { + pieceCh <- piece + } + + for _, peer := range ct.Peers { + go peer.Download(pieceCh, fileCh) + } + + go func() { + wg.Wait() + close(fileCh) + close(pieceCh) + }() + + return ct.WriteFile(pieceCh, fileCh, out, &wg) +} + +func (ct *ClientTorrent) Close() (err error) { + for _, peer := range ct.Peers { + err = errors.Join(err, peer.Close()) + } + + return +} + +func (ct *ClientTorrent) WriteFile(pieceCh chan Piece, fileCh chan FileResult, out string, wg *sync.WaitGroup) error { + f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return err + } + + defer f.Close() + + for fr := range fileCh { + if _, err = f.Seek(int64(fr.Piece.Index*ct.Meta.Info.PieceLength), io.SeekStart); err != nil { + slog.Error("failed to seek", "err", err) + pieceCh <- fr.Piece + continue + } + + if _, err = fr.WriteTo(f); err != nil { + slog.Error("failed to write piece to file", "err", err) + pieceCh <- fr.Piece + continue + } + + if err = f.Sync(); err != nil { + slog.Error("failed to sync file to disk", "err", err) + continue + } + + slog.Debug("writing to file", "piece index", fr.Piece.Index) + wg.Done() + } + + return nil +} diff --git a/cmd/mybittorrent/main.go b/cmd/mybittorrent/main.go index 80a9b7d..6dbdda6 100644 --- a/cmd/mybittorrent/main.go +++ b/cmd/mybittorrent/main.go @@ -75,6 +75,9 @@ func main() { peerAddr := os.Args[3] peer, err := c.Handshake(fn, peerAddr) + + defer peer.Close() + if err != nil { panic(err) } @@ -91,6 +94,7 @@ func main() { } c := createClient(fn) + defer c.Close() pr, err := c.GetPeers(fn) if err != nil { panic(err) @@ -109,15 +113,39 @@ func main() { panic(fmt.Errorf("no peers found for file: %s", out)) } - defer peer.Close() + f, err := os.Create(out) + if err != nil { + panic(f) + } - err = peer.DownloadPiece(out, index) + defer f.Close() + + err = peer.DownloadPiece(f, c.Torrents[fn].Meta.Pieces()[index]) if err != nil { panic(err) } fmt.Printf("Piece %d downloaded to %s.", index, out) + case "download": + out := os.Args[3] + fn := os.Args[4] + c := createClient(fn) + + defer c.Close() + + ct, err := c.ConnectPeers(fn) + if err != nil { + panic(err) + } + + err = ct.Download(out) + if err != nil { + panic(err) + } + + fmt.Printf("Downloaded %s to %s", fn, out) + default: fmt.Println("Unknown command: " + command) os.Exit(1) diff --git a/cmd/mybittorrent/meta.go b/cmd/mybittorrent/meta.go index 27bcc12..11ade01 100644 --- a/cmd/mybittorrent/meta.go +++ b/cmd/mybittorrent/meta.go @@ -3,6 +3,7 @@ package main import ( "bytes" "crypto/sha1" + "fmt" "math" bencode "github.com/jackpal/bencode-go" @@ -22,6 +23,13 @@ type FileInfo struct { Pieces string `bencode:"pieces"` } +type Piece struct { + Index int + Len int + Hash string + Blocks []uint32 +} + func (m Meta) InfoHash() ([]byte, error) { sha := sha1.New() if err := bencode.Marshal(sha, m.Info); err != nil { @@ -40,16 +48,6 @@ func (m Meta) PieceHashes() []string { return hashes } -func (m Meta) CheckHash(pieceIndex int, data []byte) bool { - sha := sha1.New() - - if _, err := bytes.NewBuffer(data).WriteTo(sha); err != nil { - return false - } - - return bytes.Equal([]byte(m.Info.Pieces[pieceIndex*20:pieceIndex*20+20]), sha.Sum(nil)) -} - func (m Meta) PieceCount() int { return len(m.Info.Pieces) / 20 } @@ -84,3 +82,40 @@ func (m Meta) BlockLens(pieceIdx int) []uint32 { return blocks } + +func (p Piece) CheckHash(data []byte) error { + sha := sha1.New() + if _, err := bytes.NewBuffer(data).WriteTo(sha); err != nil { + return err + } + + exp, act := []byte(p.Hash), sha.Sum(nil) + if !bytes.Equal(exp, act) { + return fmt.Errorf("expected hash: %x, actual: %x", exp, act) + } + + return nil +} + +func (m Meta) Pieces() []Piece { + cnt := m.PieceCount() + pieces := make([]Piece, cnt) + + for i := 0; i < cnt; i++ { + p := Piece{ + Index: i, + Hash: m.Info.Pieces[i*20 : i*20+20], + Blocks: m.BlockLens(i), + } + + if i < cnt-1 { + p.Len = m.Info.PieceLength + } else { + p.Len = m.Info.Length - i*m.Info.PieceLength + } + + pieces[i] = p + } + + return pieces +} diff --git a/cmd/mybittorrent/peer.go b/cmd/mybittorrent/peer.go index 3b40493..aee5f26 100644 --- a/cmd/mybittorrent/peer.go +++ b/cmd/mybittorrent/peer.go @@ -7,8 +7,8 @@ import ( "errors" "fmt" "io" + "log/slog" "net" - "os" ) type HandshakeMessage []byte @@ -17,7 +17,7 @@ type Peer struct { conn net.Conn handshake HandshakeMessage ct *ClientTorrent - msgCh chan *IncomingMessage + unchoked bool } func (p *Peer) PeerIdHexString() string { @@ -37,31 +37,36 @@ func (p *Peer) Close() error { return nil } -func (p *Peer) DownloadPiece(outFile string, index int) error { - if index >= p.ct.Meta.PieceCount() { +func (p *Peer) DownloadPiece(w io.Writer, piece Piece) error { + if piece.Index >= p.ct.Meta.PieceCount() { return nil } var data = new(bytes.Buffer) - blockLens := p.ct.Meta.BlockLens(index) var blockIndex = 0 requestFn := func() error { - if blockIndex == len(blockLens) { + if blockIndex == len(piece.Blocks) { return nil } r := RequestPayload{ - Index: uint32(index), + Index: uint32(piece.Index), Begin: uint32(blockIndex * BlockSize), - Length: blockLens[blockIndex], + Length: piece.Blocks[blockIndex], } return p.WriteMessage(MessageTypeRequest, r.Bytes()) } + if p.unchoked { + if err := requestFn(); err != nil { + return err + } + } + for { - if blockIndex == len(blockLens) { + if blockIndex == len(piece.Blocks) { break } @@ -85,6 +90,7 @@ func (p *Peer) DownloadPiece(outFile string, index int) error { if err = requestFn(); err != nil { return err } + p.unchoked = true case MessageTypePiece: var block BlockPayload @@ -105,31 +111,24 @@ func (p *Peer) DownloadPiece(outFile string, index int) error { return err } + case MessageTypeChoke: + p.unchoked = false + default: return fmt.Errorf("unimplemented message type: %d", msg.MessageType) } } - if !p.ct.Meta.CheckHash(index, data.Bytes()) { - return fmt.Errorf("invalid hash value") - } - - out, err := os.Create(outFile) - if err != nil { - return errors.Join(fmt.Errorf("failed to create file"), err) - } + slog.Debug("piece downloaded", "piece index", piece.Index) - _, err = out.Write(data.Bytes()) - if err != nil { - return errors.Join(fmt.Errorf("failed to write file"), err) + if err := piece.CheckHash(data.Bytes()); err != nil { + return fmt.Errorf("invalid hash value: %v", err) } - err = out.Close() - if err != nil { - return errors.Join(fmt.Errorf("failed to close file"), err) - } - - return nil + slog.Debug("hash ok", "piece index", piece.Index) + _, err := w.Write(data.Bytes()) + slog.Debug("written to writer", "piece index", piece.Index) + return err } func (p *Peer) ReadMessage() (*IncomingMessage, error) { @@ -179,3 +178,21 @@ func (p *Peer) WriteMessage(t MessageType, payload []byte) error { _, err := msg.Write(payload) return err } + +func (p *Peer) Download(pieceCh chan Piece, fileCh chan FileResult) { + for piece := range pieceCh { + fw := FileWriter{ + ch: fileCh, + piece: piece, + } + + err := p.DownloadPiece(&fw, piece) + if err != nil { + slog.Error("failed to download piece", "piece", piece, "err", err) + pieceCh <- piece + continue + } + + slog.Debug("downloaded piece", "piece index", piece.Index, "left", len(pieceCh)) + } +} |