From 210fb1e02453413d1ce070b70c850807286a1a7a Mon Sep 17 00:00:00 2001 From: jet2tlf Date: Mon, 3 Jun 2024 15:31:42 -0300 Subject: codecrafters submit [skip ci] --- cmd/mybittorrent/peer.go | 69 ++++++++++++++++++++++++++++++------------------ 1 file changed, 43 insertions(+), 26 deletions(-) (limited to 'cmd/mybittorrent/peer.go') diff --git a/cmd/mybittorrent/peer.go b/cmd/mybittorrent/peer.go index 3b40493..aee5f26 100644 --- a/cmd/mybittorrent/peer.go +++ b/cmd/mybittorrent/peer.go @@ -7,8 +7,8 @@ import ( "errors" "fmt" "io" + "log/slog" "net" - "os" ) type HandshakeMessage []byte @@ -17,7 +17,7 @@ type Peer struct { conn net.Conn handshake HandshakeMessage ct *ClientTorrent - msgCh chan *IncomingMessage + unchoked bool } func (p *Peer) PeerIdHexString() string { @@ -37,31 +37,36 @@ func (p *Peer) Close() error { return nil } -func (p *Peer) DownloadPiece(outFile string, index int) error { - if index >= p.ct.Meta.PieceCount() { +func (p *Peer) DownloadPiece(w io.Writer, piece Piece) error { + if piece.Index >= p.ct.Meta.PieceCount() { return nil } var data = new(bytes.Buffer) - blockLens := p.ct.Meta.BlockLens(index) var blockIndex = 0 requestFn := func() error { - if blockIndex == len(blockLens) { + if blockIndex == len(piece.Blocks) { return nil } r := RequestPayload{ - Index: uint32(index), + Index: uint32(piece.Index), Begin: uint32(blockIndex * BlockSize), - Length: blockLens[blockIndex], + Length: piece.Blocks[blockIndex], } return p.WriteMessage(MessageTypeRequest, r.Bytes()) } + if p.unchoked { + if err := requestFn(); err != nil { + return err + } + } + for { - if blockIndex == len(blockLens) { + if blockIndex == len(piece.Blocks) { break } @@ -85,6 +90,7 @@ func (p *Peer) DownloadPiece(outFile string, index int) error { if err = requestFn(); err != nil { return err } + p.unchoked = true case MessageTypePiece: var block BlockPayload @@ -105,31 +111,24 @@ func (p *Peer) DownloadPiece(outFile string, index int) error { return err } + case MessageTypeChoke: + p.unchoked = false + default: return fmt.Errorf("unimplemented message type: %d", msg.MessageType) } } - if !p.ct.Meta.CheckHash(index, data.Bytes()) { - return fmt.Errorf("invalid hash value") - } - - out, err := os.Create(outFile) - if err != nil { - return errors.Join(fmt.Errorf("failed to create file"), err) - } + slog.Debug("piece downloaded", "piece index", piece.Index) - _, err = out.Write(data.Bytes()) - if err != nil { - return errors.Join(fmt.Errorf("failed to write file"), err) + if err := piece.CheckHash(data.Bytes()); err != nil { + return fmt.Errorf("invalid hash value: %v", err) } - err = out.Close() - if err != nil { - return errors.Join(fmt.Errorf("failed to close file"), err) - } - - return nil + slog.Debug("hash ok", "piece index", piece.Index) + _, err := w.Write(data.Bytes()) + slog.Debug("written to writer", "piece index", piece.Index) + return err } func (p *Peer) ReadMessage() (*IncomingMessage, error) { @@ -179,3 +178,21 @@ func (p *Peer) WriteMessage(t MessageType, payload []byte) error { _, err := msg.Write(payload) return err } + +func (p *Peer) Download(pieceCh chan Piece, fileCh chan FileResult) { + for piece := range pieceCh { + fw := FileWriter{ + ch: fileCh, + piece: piece, + } + + err := p.DownloadPiece(&fw, piece) + if err != nil { + slog.Error("failed to download piece", "piece", piece, "err", err) + pieceCh <- piece + continue + } + + slog.Debug("downloaded piece", "piece index", piece.Index, "left", len(pieceCh)) + } +} -- cgit v1.2.3