diff options
Diffstat (limited to 'cmd/mybittorrent')
| -rw-r--r-- | cmd/mybittorrent/client.go | 51 | ||||
| -rw-r--r-- | cmd/mybittorrent/main.go | 54 | ||||
| -rw-r--r-- | cmd/mybittorrent/message.go | 76 | ||||
| -rw-r--r-- | cmd/mybittorrent/meta.go | 49 | ||||
| -rw-r--r-- | cmd/mybittorrent/peer.go | 181 |
5 files changed, 376 insertions, 35 deletions
diff --git a/cmd/mybittorrent/client.go b/cmd/mybittorrent/client.go index ef3d860..8b3aac8 100644 --- a/cmd/mybittorrent/client.go +++ b/cmd/mybittorrent/client.go @@ -3,7 +3,6 @@ package main import ( "bytes" "encoding/binary" - "encoding/hex" "fmt" "io" "log/slog" @@ -19,7 +18,7 @@ import ( type Client struct { PeerId string Port int - Torrents map[string]ClientTorrent + Torrents map[string]*ClientTorrent } type ClientTorrent struct { @@ -34,17 +33,15 @@ type PeerResponse struct { Peers []string } -type HandshakeMessage []byte - func NewClient(peerId string, port int) *Client { return &Client{ PeerId: peerId, Port: port, - Torrents: make(map[string]ClientTorrent), + Torrents: make(map[string]*ClientTorrent), } } -func (c *Client) AddTorrentFile(filename string) (ClientTorrent, error) { +func (c *Client) AddTorrentFile(filename string) error { f, err := os.Open(filename) defer func(f *os.File) { err := f.Close() @@ -60,19 +57,17 @@ func (c *Client) AddTorrentFile(filename string) (ClientTorrent, error) { var meta Meta if err = bencode.Unmarshal(f, &meta); err != nil { - return ClientTorrent{}, err + return err } - t := ClientTorrent{ + c.Torrents[filename] = &ClientTorrent{ Meta: meta, Uploaded: 0, Downloaded: 0, Left: meta.Info.Length, } - c.Torrents[filename] = t - - return t, nil + return nil } func (ct ClientTorrent) getUrl(c Client) (string, error) { @@ -141,13 +136,7 @@ func (c *Client) GetPeers(filename string) (PeerResponse, error) { }, err } -func (m HandshakeMessage) PeerIdHex() string { - - return hex.EncodeToString(m[48:]) - -} - -func (c *Client) Handshake(filename, peerAddr string) (HandshakeMessage, error) { +func (c *Client) Handshake(filename, peerAddr string) (*Peer, error) { ct, ok := c.Torrents[filename] if !ok { return nil, fmt.Errorf("missing torrent file: %s", filename) @@ -167,12 +156,6 @@ func (c *Client) Handshake(filename, peerAddr string) (HandshakeMessage, error) buf.WriteString(c.PeerId) conn, err := net.Dial("tcp", peerAddr) - defer func(conn net.Conn) { - err := conn.Close() - if err != nil { - slog.Error("failed to close connection", "peerAddr", peerAddr) - } - }(conn) if err != nil { return nil, err @@ -184,7 +167,23 @@ func (c *Client) Handshake(filename, peerAddr string) (HandshakeMessage, error) } respBuf := make([]byte, 68) - _, err = io.LimitReader(conn, 68).Read(respBuf) - return respBuf, nil + _, err = io.ReadFull(conn, respBuf) + + peer := &Peer{ + conn: conn, + handshake: respBuf, + ct: ct, + } + + if !bytes.Equal(peer.InfoHash(), hash) { + err := conn.Close() + if err != nil { + slog.Error("Failed to close peer connection", "remoteAddr", peer.conn.RemoteAddr()) + } + + return nil, fmt.Errorf("invalid info hash from peer: %x, addr: %s", peer.InfoHash(), peerAddr) + } + + return peer, nil } diff --git a/cmd/mybittorrent/main.go b/cmd/mybittorrent/main.go index f000aa5..80a9b7d 100644 --- a/cmd/mybittorrent/main.go +++ b/cmd/mybittorrent/main.go @@ -4,17 +4,16 @@ import ( "encoding/json" "fmt" "os" + "strconv" "strings" bencode "github.com/jackpal/bencode-go" // Available if you need it! ) -func createClient() *Client { - fn := os.Args[2] +func createClient(fn string) *Client { c := NewClient("00112233445566778899", 6881) - _, err := c.AddTorrentFile(fn) - if err != nil { + if err := c.AddTorrentFile(fn); err != nil { panic(err) } @@ -38,7 +37,7 @@ func main() { case "info": fn := os.Args[2] - c := createClient() + c := createClient(fn) meta := c.Torrents[fn].Meta fmt.Printf("Tracker URL: %s\n", meta.Announce) fmt.Printf("Length: %d\n", meta.Info.Length) @@ -59,7 +58,7 @@ func main() { case "peers": fn := os.Args[2] - c := createClient() + c := createClient(fn) pr, err := c.GetPeers(fn) if err != nil { @@ -72,15 +71,52 @@ func main() { case "handshake": fn := os.Args[2] - c := createClient() + c := createClient(fn) peerAddr := os.Args[3] - hs, err := c.Handshake(fn, peerAddr) + peer, err := c.Handshake(fn, peerAddr) if err != nil { panic(err) } - fmt.Printf("Peer ID: %s\n", hs.PeerIdHex()) + fmt.Printf("Peer ID: %s\n", peer.PeerIdHexString()) + + case "download_piece": + out := os.Args[3] + fn := os.Args[4] + + index, err := strconv.Atoi(os.Args[5]) + if err != nil { + panic(err) + } + + c := createClient(fn) + pr, err := c.GetPeers(fn) + if err != nil { + panic(err) + } + + var peer *Peer + + for _, peerAddr := range pr.Peers { + peer, err = c.Handshake(fn, peerAddr) + if err != nil { + continue + } + } + + if peer == nil { + panic(fmt.Errorf("no peers found for file: %s", out)) + } + + defer peer.Close() + + err = peer.DownloadPiece(out, index) + if err != nil { + panic(err) + } + + fmt.Printf("Piece %d downloaded to %s.", index, out) default: fmt.Println("Unknown command: " + command) diff --git a/cmd/mybittorrent/message.go b/cmd/mybittorrent/message.go new file mode 100644 index 0000000..9439bd5 --- /dev/null +++ b/cmd/mybittorrent/message.go @@ -0,0 +1,76 @@ +package main + +import ( + "encoding/binary" + "io" +) + +type MessageType byte + +type BlockPayload struct { + Index uint32 + Begin uint32 + Block []byte +} + +type IncomingMessage struct { + Len uint32 + MessageType MessageType + Payload []byte +} + +type OutgoingMessage struct { + MessageType MessageType + Writer io.Writer +} + +type RequestPayload struct { + Index uint32 + Begin uint32 + Length uint32 +} + +const ( + MessageTypeChoke MessageType = iota + MessageTypeUnchoke + MessageTypeInterested + MessageTypeNotInterested + MessageTypeHave + MessageTypeBitfield + MessageTypeRequest + MessageTypePiece + MessageTypeCancel +) + +func (o *OutgoingMessage) Write(b []byte) (int, error) { + msgLen := 1 + len(b) + payloadBuff := make([]byte, msgLen+4) + binary.BigEndian.PutUint32(payloadBuff[0:4], uint32(msgLen)) + payloadBuff[4] = byte(o.MessageType) + + if msgLen > 1 { + copy(payloadBuff[5:], b) + } + + return o.Writer.Write(payloadBuff) +} + +func (r RequestPayload) Bytes() []byte { + buf := make([]byte, 12) + binary.BigEndian.PutUint32(buf[0:4], r.Index) + binary.BigEndian.PutUint32(buf[4:8], r.Begin) + binary.BigEndian.PutUint32(buf[8:12], r.Length) + return buf +} + +func (p *BlockPayload) Write(b []byte) (int, error) { + p.Index = binary.BigEndian.Uint32(b[0:4]) + p.Begin = binary.BigEndian.Uint32(b[4:8]) + p.Block = b[8:] + return len(b), nil +} + +func (p *BlockPayload) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(p.Block) + return int64(n), err +} diff --git a/cmd/mybittorrent/meta.go b/cmd/mybittorrent/meta.go index 34369c4..27bcc12 100644 --- a/cmd/mybittorrent/meta.go +++ b/cmd/mybittorrent/meta.go @@ -1,11 +1,15 @@ package main import ( + "bytes" "crypto/sha1" + "math" bencode "github.com/jackpal/bencode-go" ) +const BlockSize = 16 * 1024 + type Meta struct { Announce string `bencode:"announce"` Info FileInfo `bencode:"info"` @@ -35,3 +39,48 @@ func (m Meta) PieceHashes() []string { return hashes } + +func (m Meta) CheckHash(pieceIndex int, data []byte) bool { + sha := sha1.New() + + if _, err := bytes.NewBuffer(data).WriteTo(sha); err != nil { + return false + } + + return bytes.Equal([]byte(m.Info.Pieces[pieceIndex*20:pieceIndex*20+20]), sha.Sum(nil)) +} + +func (m Meta) PieceCount() int { + return len(m.Info.Pieces) / 20 +} + +func (m Meta) PieceLens() []int { + pieceCnt := m.PieceCount() + pieces := make([]int, pieceCnt) + + for i := 0; i < pieceCnt; i++ { + if i < pieceCnt-1 { + pieces[i] = m.Info.PieceLength + } else { + pieces[i] = m.Info.Length - i*m.Info.PieceLength + } + } + + return pieces +} + +func (m Meta) BlockLens(pieceIdx int) []uint32 { + pieceLen := m.PieceLens()[pieceIdx] + blockCnt := int(math.Ceil(float64(pieceLen) / float64(BlockSize))) + blocks := make([]uint32, blockCnt) + + for i := 0; i < blockCnt; i++ { + if i < blockCnt-1 { + blocks[i] = uint32(BlockSize) + } else { + blocks[i] = uint32(pieceLen - i*BlockSize) + } + } + + return blocks +} diff --git a/cmd/mybittorrent/peer.go b/cmd/mybittorrent/peer.go new file mode 100644 index 0000000..3b40493 --- /dev/null +++ b/cmd/mybittorrent/peer.go @@ -0,0 +1,181 @@ +package main + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "net" + "os" +) + +type HandshakeMessage []byte + +type Peer struct { + conn net.Conn + handshake HandshakeMessage + ct *ClientTorrent + msgCh chan *IncomingMessage +} + +func (p *Peer) PeerIdHexString() string { + return hex.EncodeToString(p.handshake[48:]) +} + +func (p *Peer) InfoHash() []byte { + return p.handshake[28:48] +} + +func (p *Peer) Close() error { + err := p.conn.Close() + if err != nil { + return errors.Join(fmt.Errorf("failed to close peer connection: %s", p.conn.RemoteAddr()), err) + } + + return nil +} + +func (p *Peer) DownloadPiece(outFile string, index int) error { + if index >= p.ct.Meta.PieceCount() { + return nil + } + + var data = new(bytes.Buffer) + blockLens := p.ct.Meta.BlockLens(index) + var blockIndex = 0 + + requestFn := func() error { + if blockIndex == len(blockLens) { + return nil + } + + r := RequestPayload{ + Index: uint32(index), + Begin: uint32(blockIndex * BlockSize), + Length: blockLens[blockIndex], + } + + return p.WriteMessage(MessageTypeRequest, r.Bytes()) + } + + for { + if blockIndex == len(blockLens) { + break + } + + msg, err := p.ReadMessage() + if err != nil { + return err + } + + switch msg.MessageType { + case MessageTypeBitfield: + if err != nil { + return err + } + + err = p.WriteMessage(MessageTypeInterested, nil) + if err != nil { + return err + } + + case MessageTypeUnchoke: + if err = requestFn(); err != nil { + return err + } + + case MessageTypePiece: + var block BlockPayload + + _, err = block.Write(msg.Payload) + if err != nil { + return err + } + + _, err = data.Write(block.Block) + if err != nil { + return err + } + + blockIndex++ + + if err = requestFn(); err != nil { + return err + } + + default: + return fmt.Errorf("unimplemented message type: %d", msg.MessageType) + } + } + + if !p.ct.Meta.CheckHash(index, data.Bytes()) { + return fmt.Errorf("invalid hash value") + } + + out, err := os.Create(outFile) + if err != nil { + return errors.Join(fmt.Errorf("failed to create file"), err) + } + + _, err = out.Write(data.Bytes()) + if err != nil { + return errors.Join(fmt.Errorf("failed to write file"), err) + } + + err = out.Close() + if err != nil { + return errors.Join(fmt.Errorf("failed to close file"), err) + } + + return nil +} + +func (p *Peer) ReadMessage() (*IncomingMessage, error) { + lenBuf := make([]byte, 4) + + _, err := p.conn.Read(lenBuf) + if err != nil { + return nil, errors.Join(fmt.Errorf("failed to read message length"), err) + } + + msgLen := binary.BigEndian.Uint32(lenBuf) - 1 + typeBuf := make([]byte, 1) + + _, err = p.conn.Read(typeBuf) + if err != nil { + return nil, errors.Join(fmt.Errorf("failed to read message type"), err) + } + + msgType := MessageType(typeBuf[0]) + if msgLen == 0 { + return &IncomingMessage{ + Len: msgLen, + MessageType: msgType, + }, nil + } + + payloadBuf := make([]byte, msgLen) + + _, err = io.ReadFull(p.conn, payloadBuf) + if err != nil { + return nil, errors.Join(fmt.Errorf("failed to read message payload"), err) + } + + return &IncomingMessage{ + Len: msgLen, + MessageType: msgType, + Payload: payloadBuf, + }, nil +} + +func (p *Peer) WriteMessage(t MessageType, payload []byte) error { + msg := &OutgoingMessage{ + MessageType: t, + Writer: p.conn, + } + + _, err := msg.Write(payload) + return err +} |