-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathmodel.py
More file actions
49 lines (40 loc) · 1.54 KB
/
model.py
File metadata and controls
49 lines (40 loc) · 1.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
This module provides a function to get a model based on the configuration.
"""
import os
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from src.lib.state import AgentState
# Provider API keys, read from the process environment once, when this module
# is imported. A key is None if its environment variable is not set.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY")
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
def _require_api_key(env_var: str, cached: "str | None") -> str:
    """
    Return the API key for ``env_var``, raising if it is not available.

    The value captured at import time (``cached``) is preferred. If it is
    empty, the environment is consulted again at call time, so keys that were
    exported after this module was imported are still picked up.

    Raises:
        ValueError: If neither the cached value nor the environment provides
            a non-empty key.
    """
    key = cached or os.getenv(env_var)
    if not key:
        raise ValueError(f"{env_var} environment variable is not set")
    return key


def get_model(state: AgentState) -> BaseChatModel:
    """
    Return a chat model for the configured provider.

    The provider name comes from the ``MODEL`` environment variable when it is
    set. Otherwise it comes from ``state["model"]``, which defaults to
    ``"openai"``.

    Supported providers (all created with ``temperature=0``):
        - ``"openai"``: ``ChatOpenAI`` with ``gpt-4o-mini``
        - ``"anthropic"``: ``ChatAnthropic`` with ``claude-3-5-sonnet-20240620``
        - ``"google_genai"``: ``ChatGoogleGenerativeAI`` with ``gemini-1.5-pro``

    Args:
        state: Agent state. Only its optional ``"model"`` key is read.

    Returns:
        A configured chat model for the selected provider.

    Raises:
        ValueError: If the selected provider's API key is not set, or if the
            provider name is not one of the supported values.
    """
    state_model = state.get("model", "openai")
    # The MODEL environment variable deliberately takes precedence over state.
    model = os.getenv("MODEL", state_model)

    if model == "openai":
        api_key = _require_api_key("OPENAI_API_KEY", OPENAI_API_KEY)
        return ChatOpenAI(temperature=0, model="gpt-4o-mini", api_key=api_key)

    if model == "anthropic":
        api_key = _require_api_key("ANTHROPIC_API_KEY", ANTHROPIC_API_KEY)
        return ChatAnthropic(
            temperature=0,
            model_name="claude-3-5-sonnet-20240620",
            # Pass the validated key explicitly, as the other providers do,
            # instead of relying on the client re-reading the environment.
            api_key=api_key,
            timeout=None,
            stop=None,
        )

    if model == "google_genai":
        api_key = _require_api_key("GOOGLE_API_KEY", GOOGLE_API_KEY)
        return ChatGoogleGenerativeAI(
            temperature=0,
            model="gemini-1.5-pro",
            api_key=api_key,
        )

    raise ValueError(
        f"Invalid model specified: {model!r} "
        "(expected 'openai', 'anthropic' or 'google_genai')"
    )