diff --git a/pkg/add/add.go b/pkg/add/add.go
new file mode 100644
index 0000000..447a184
--- /dev/null
+++ b/pkg/add/add.go
@@ -0,0 +1,43 @@
+package add
+
+import (
+	"fmt"
+	"os"
+
+	"github.com/urfave/cli/v2"
+	"github.com/yusufcanb/tlm/pkg/chroma"
+)
+
+// Command returns the "add" CLI command. It reads the file named by the
+// first argument and stores its whole content as a single document in the
+// default ChromaDB collection, using the file path as the document ID.
+func Command() *cli.Command {
+	return &cli.Command{
+		Name:      "add",
+		Usage:     "Adds a document to the ChromaDB collection.",
+		ArgsUsage: "<file>",
+		Action: func(c *cli.Context) error {
+			if c.NArg() == 0 {
+				return fmt.Errorf("a file path is required")
+			}
+
+			filePath := c.Args().First()
+			content, err := os.ReadFile(filePath)
+			if err != nil {
+				return fmt.Errorf("failed to read file: %w", err)
+			}
+
+			chromaClient := chroma.NewChromaClient(chroma.DefaultBaseURL)
+			err = chromaClient.Add(chroma.DefaultCollection, &chroma.AddRequest{
+				Documents: []string{string(content)},
+				IDs:       []string{filePath},
+			})
+			if err != nil {
+				return fmt.Errorf("failed to add document: %w", err)
+			}
+
+			fmt.Printf("Document '%s' added successfully.\n", filePath)
+			return nil
+		},
+	}
+}
diff --git a/pkg/app/app.go b/pkg/app/app.go
index 31c4dc6..37e4453 100644
--- a/pkg/app/app.go
+++ b/pkg/app/app.go
@@ -7,6 +7,7 @@ import (
 	ollama "github.com/jmorganca/ollama/api"
 
+	"github.com/yusufcanb/tlm/pkg/add"
 	"github.com/yusufcanb/tlm/pkg/ask"
 	"github.com/yusufcanb/tlm/pkg/config"
 	"github.com/yusufcanb/tlm/pkg/explain"
 	"github.com/yusufcanb/tlm/pkg/suggest"
@@ -49,6 +50,7 @@ func New(version, buildSha string) *TlmApp {
 			return cli.ShowAppHelp(c)
 		},
 		Commands: []*cli.Command{
+			add.Command(),
 			ask.Command(),
 			sug.Command(),
 			exp.Command(),
diff --git a/pkg/ask/cli.go b/pkg/ask/cli.go
index 9b67f79..2ffccd1 100644
--- a/pkg/ask/cli.go
+++ b/pkg/ask/cli.go
@@ -6,7 +6,7 @@ import (
 
 	"github.com/spf13/viper"
 	"github.com/urfave/cli/v2"
-	"github.com/yusufcanb/tlm/pkg/packer"
+	"github.com/yusufcanb/tlm/pkg/chroma"
 	"github.com/yusufcanb/tlm/pkg/rag"
 )
 
@@ -36,40 +36,14 @@ func (a *Ask) beforeAction(c *cli.Context) error {
 
 func (a *Ask) action(c *cli.Context) error {
 	isInteractive := c.Bool("interactive")
-	contextDir := c.Path("context")
-
-	var chatContext string    // chat context
 	var numCtx int = 1024 * 8 // num_ctx in Ollama API
 
-	if contextDir != "" {
-		includePatterns := c.StringSlice("include")
-		excludePatterns := c.StringSlice("exclude")
-		// fmt.Printf("include=%v, exclude=%v\n\n", includePatterns, excludePatterns)
-
-		// Pack files under the context directory
-		packer := packer.New()
-		res, err := packer.Pack(contextDir, includePatterns, excludePatterns)
-		if err != nil {
-			return err
-		}
-
-		// Sort the files by the number of tokens
-		packer.PrintTopFiles(res, 5)
-
-		// Print the context summary
-		packer.PrintContextSummary(res)
-
-		// Render the packer result
-		chatContext, err = packer.Render(res)
-		if err != nil {
-			return err
-		}
-	}
-
 	fmt.Printf("\n🤖 %s\n───────────────────\n", a.model)
 
+	chromaClient := chroma.NewChromaClient(chroma.DefaultBaseURL)
+
 	prompt := c.Args().First()
-	rag := rag.NewRAGChat(a.api, chatContext, a.model)
+	rag := rag.NewRAGChat(a.api, chromaClient, "", a.model)
 	_, err := rag.Send(prompt, numCtx)
 	if err != nil {
 		return err
@@ -115,21 +89,6 @@ func (a *Ask) Command() *cli.Command {
 		Before: a.beforeAction,
 		After:  a.afterAction,
 		Flags: []cli.Flag{
-			&cli.PathFlag{
-				Name:    "context",
-				Aliases: []string{"c"},
-				Usage:   "context directory path",
-			},
-			&cli.StringSliceFlag{
-				Name:    "include",
-				Aliases: []string{"i"},
-				Usage:   "include patterns. e.g. --include=*.txt or --include=*.txt,*.md",
-			},
-			&cli.StringSliceFlag{
-				Name:    "exclude",
-				Aliases: []string{"e"},
-				Usage:   "exclude patterns. e.g. --exclude=**/*_test.go or --exclude=*.pyc,*.pyd",
-			},
 			&cli.BoolFlag{
 				Name:    "interactive",
 				Aliases: []string{"it"},
diff --git a/pkg/chroma/chroma.go b/pkg/chroma/chroma.go
new file mode 100644
index 0000000..7c2e884
--- /dev/null
+++ b/pkg/chroma/chroma.go
@@ -0,0 +1,161 @@
+// Package chroma is a minimal HTTP client for the subset of the ChromaDB v1
+// REST API that tlm uses: adding documents to a collection and querying it.
+package chroma
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"strings"
+	"time"
+)
+
+const (
+	// DefaultBaseURL is the ChromaDB server address used by tlm's commands.
+	DefaultBaseURL = "http://localhost:8000"
+
+	// DefaultCollection is the collection that "tlm add" writes to and
+	// "tlm ask" reads from.
+	DefaultCollection = "tlm-collection"
+
+	// requestTimeout bounds every HTTP call so an unreachable or hung server
+	// cannot block the CLI indefinitely.
+	requestTimeout = 30 * time.Second
+)
+
+// ChromaClient is a thin HTTP client for a single ChromaDB server.
+type ChromaClient struct {
+	baseURL string
+	client  *http.Client
+}
+
+// NewChromaClient returns a client for the server at baseURL (for example
+// DefaultBaseURL). A trailing slash on baseURL is ignored.
+func NewChromaClient(baseURL string) *ChromaClient {
+	return &ChromaClient{
+		baseURL: strings.TrimRight(baseURL, "/"),
+		client:  &http.Client{Timeout: requestTimeout},
+	}
+}
+
+// AddRequest is the body of a collection "add" call; IDs[i] names
+// Documents[i].
+//
+// NOTE(review): the Chroma server does not compute embeddings itself (the
+// official Python/JS clients do that before calling it), so documents sent
+// without embeddings may be rejected or not be searchable, depending on the
+// server version. Confirm, and consider computing embeddings (e.g. via
+// Ollama) and sending them alongside.
+type AddRequest struct {
+	Documents []string `json:"documents"`
+	IDs       []string `json:"ids"`
+}
+
+// QueryRequest is the body of a collection "query" call.
+//
+// NOTE(review): Chroma's v1 HTTP query endpoint expects query_embeddings;
+// query_texts is embedded client-side by the official SDKs. Verify this
+// request against the target server version.
+type QueryRequest struct {
+	QueryTexts []string `json:"query_texts"`
+	NResults   int      `json:"n_results"`
+}
+
+// QueryResponse holds the matched documents, one inner slice per query.
+type QueryResponse struct {
+	Documents [][]string `json:"documents"`
+}
+
+// Add stores req's documents in the named collection, creating the
+// collection first if it does not exist yet.
+func (c *ChromaClient) Add(collectionName string, req *AddRequest) error {
+	id, err := c.collectionID(collectionName)
+	if err != nil {
+		return err
+	}
+	if err := c.postJSON(collectionPath(id, "add"), req, nil); err != nil {
+		return fmt.Errorf("adding documents to %q: %w", collectionName, err)
+	}
+	return nil
+}
+
+// Query returns the documents in the named collection that match
+// req.QueryTexts, creating the collection first if it does not exist yet.
+func (c *ChromaClient) Query(collectionName string, req *QueryRequest) (*QueryResponse, error) {
+	id, err := c.collectionID(collectionName)
+	if err != nil {
+		return nil, err
+	}
+	var out QueryResponse
+	if err := c.postJSON(collectionPath(id, "query"), req, &out); err != nil {
+		return nil, fmt.Errorf("querying %q: %w", collectionName, err)
+	}
+	return &out, nil
+}
+
+// collectionID resolves a collection name to the ID that the per-collection
+// v1 endpoints expect in their URL path, creating the collection if it does
+// not exist yet (get_or_create).
+func (c *ChromaClient) collectionID(name string) (string, error) {
+	body := struct {
+		Name        string `json:"name"`
+		GetOrCreate bool   `json:"get_or_create"`
+	}{Name: name, GetOrCreate: true}
+
+	var coll struct {
+		ID string `json:"id"`
+	}
+	if err := c.postJSON("/api/v1/collections", body, &coll); err != nil {
+		return "", fmt.Errorf("resolving collection %q: %w", name, err)
+	}
+	if coll.ID == "" {
+		return "", fmt.Errorf("resolving collection %q: server returned no id", name)
+	}
+	return coll.ID, nil
+}
+
+// collectionPath builds the URL path of a per-collection endpoint.
+func collectionPath(id, op string) string {
+	return "/api/v1/collections/" + url.PathEscape(id) + "/" + op
+}
+
+// postJSON POSTs in, encoded as JSON, to baseURL+path. Any 2xx status counts
+// as success (Chroma's add endpoint answers 201 Created, not 200); the body
+// is then decoded into out unless out is nil. Any other status becomes an
+// error carrying the status line and the start of the response body.
+func (c *ChromaClient) postJSON(path string, in, out interface{}) error {
+	payload, err := json.Marshal(in)
+	if err != nil {
+		return fmt.Errorf("encoding request: %w", err)
+	}
+
+	req, err := http.NewRequest(http.MethodPost, c.baseURL+path, bytes.NewReader(payload))
+	if err != nil {
+		return err
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := c.client.Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode < 200 || resp.StatusCode > 299 {
+		// Best effort: the body only enriches the error message, so a read
+		// failure here is deliberately ignored.
+		detail, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
+		return fmt.Errorf("%s: %s", resp.Status, strings.TrimSpace(string(detail)))
+	}
+
+	if out == nil {
+		// Drain the body so the connection can be reused. The call already
+		// succeeded, so a drain failure is not worth reporting.
+		_, _ = io.Copy(io.Discard, resp.Body)
+		return nil
+	}
+	return json.NewDecoder(resp.Body).Decode(out)
+}
diff --git a/pkg/rag/rag.go b/pkg/rag/rag.go
index 129ee1e..c47e452 100644
--- a/pkg/rag/rag.go
+++ b/pkg/rag/rag.go
@@ -6,11 +6,14 @@ import (
 	"fmt"
 	"strings"
 
+	"github.com/yusufcanb/tlm/pkg/chroma"
+
 	ollama "github.com/jmorganca/ollama/api"
 )
 
 type RAGChat struct {
-	api *ollama.Client
+	api          *ollama.Client
+	chromaClient *chroma.ChromaClient
 
 	model   string
 	context string
@@ -28,10 +31,28 @@ func (r *RAGChat) Send(message string, numCtx int) (string, error) {
 			Content: "You are a software engineer and a helpful assistant.",
 		})
 
+		// query chroma for context
+		queryResp, err := r.chromaClient.Query(chroma.DefaultCollection, &chroma.QueryRequest{
+			QueryTexts: []string{message},
+			NResults:   5,
+		})
+		if err != nil {
+			return "", fmt.Errorf("error querying chroma: %w", err)
+		}
+
+		// build context from chroma response
+		var retrieved strings.Builder
+		for _, docs := range queryResp.Documents {
+			for _, d := range docs {
+				retrieved.WriteString(d)
+				retrieved.WriteString("\n")
+			}
+		}
+
 		// Add context as the first message
 		r.history = append(r.history, ollama.Message{
 			Role:    "user",
-			Content: r.context + "\n" + message,
+			Content: retrieved.String() + "\n" + message,
 		})
 	} else { // if history exists, add the new message to the history
 		r.history = append(r.history, ollama.Message{
@@ -66,11 +87,17 @@ func (r *RAGChat) Send(message string, numCtx int) (string, error) {
 	return "", nil
 }
 
-func NewRAGChat(api *ollama.Client, context string, model string) *RAGChat {
+// NewRAGChat creates a chat session for model using the Ollama client api.
+// On the first Send it retrieves related documents through chromaClient and
+// prepends them to the user's message.
+// NOTE(review): the context argument is stored but no longer appears to be
+// read by Send; confirm whether it should still be prepended to the prompt.
+func NewRAGChat(api *ollama.Client, chromaClient *chroma.ChromaClient, context string, model string) *RAGChat {
 	return &RAGChat{
-		api:     api,
-		model:   model,
-		context: context,
-		history: make([]ollama.Message, 0),
+		api:          api,
+		chromaClient: chromaClient,
+		model:        model,
+		context:      context,
+		history:      make([]ollama.Message, 0),
 	}
 }